import torch

class Trajectories:
    """
    Class to store trajectories and their associated log-probabilities.

    Tensors are indexed batch-first: ``states[b, t, :]`` is the t-th state of
    trajectory ``b``. ``length`` is the number of transitions (one less than
    the number of stored states along axis 1) and ``dim`` is the state width
    without its last column, which :meth:`get_final_states` also drops
    (presumably an auxiliary column such as a step counter -- TODO confirm
    against the environment).
    """

    # Tensors whose leading axis is the trajectory axis; `prune` and `detach`
    # process them together so they stay aligned.
    _PER_TRAJECTORY_ATTRS = (
        'states', 'actions', 'log_rewards', 'initial_states',
        'logPF', 'logPB', 'log_fullPF', 'log_fullPB', 'logF',
    )
    # Quantities computed from a gfn; `slice_trajs` copies them onto the slice.
    _COMPUTED_ATTRS = ('logPF', 'logPB', 'log_fullPF', 'log_fullPB', 'logF')

    def __init__(self, states=None, actions=None, log_rewards=None, device=None):
        self.states = states
        self.actions = actions
        self.log_rewards = log_rewards
        # Fall back to the device of `states` only when no device was given;
        # an explicit falsy device such as the CUDA ordinal 0 must be kept.
        if device is None and states is not None:
            device = states.device
        self.device = device
        self.batch_size, self.length, self.dim = self._infer_dimensions(states)
        self.initial_states = states[:, 0, :] if states is not None else None
        self.logPF = self.logPB = self.log_fullPF = self.log_fullPB = self.logF = None

    def _infer_dimensions(self, states):
        """Return ``(batch_size, n_transitions, dim)`` for a 3-D ``states`` tensor, or Nones."""
        if states is not None:
            batch_size, length, dim = states.shape
            return batch_size, length - 1, dim - 1
        return None, None, None

    def compute_logPF(self, gfn, head_index=None):
        """
        Compute the log probabilities according to the forward policy.

        Sets ``log_fullPF[:, t]`` to the log-probability of ``actions[:, t]``
        under ``gfn.get_forward_policy_dist(states[:, t], head_index)`` for
        every step ``t``, and ``logPF`` to their sum over steps.
        """
        self.logPF = torch.zeros((self.batch_size,), device=gfn.device)
        self.log_fullPF = torch.zeros((self.batch_size, self.length), device=gfn.device)
        for t in range(self.length):
            policy_dist = gfn.get_forward_policy_dist(self.states[:, t, :], head_index)
            log_prob = policy_dist.log_prob(self.actions[:, t, :])
            self.logPF += log_prob
            self.log_fullPF[:, t] = log_prob

    def compute_logPB(self, gfn):
        """
        Compute the log probabilities according to the backward policy.

        For ``t = length .. 2`` the log-probability of ``actions[:, t - 1]``
        under ``gfn.get_backward_policy_dist(states[:, t])`` is written to
        ``log_fullPB[:, t - 1]`` and accumulated into ``logPB``.
        """
        self.logPB = torch.zeros((self.batch_size,), device=gfn.device)
        self.log_fullPB = torch.zeros((self.batch_size, self.length), device=gfn.device)
        # NOTE(review): the step from states[:, 1] back to states[:, 0] is not
        # scored and log_fullPB[:, 0] stays 0 -- presumably because the backward
        # step into the initial state is deterministic; confirm.
        for t in range(self.length, 1, -1):
            policy_dist = gfn.get_backward_policy_dist(self.states[:, t, :])
            log_prob = policy_dist.log_prob(self.actions[:, t - 1, :])
            self.logPB += log_prob
            self.log_fullPB[:, t - 1] = log_prob

    def compute_logF(self, gfn):
        """
        Compute the flow according to the flow model.

        ``logF[:, t]`` is ``gfn.logF_model(states[:, t])`` for ``t < length``.
        Each model output is reshaped to one value per trajectory rather than
        squeezed, so a batch of a single trajectory keeps its batch axis.
        """
        flows = []
        for t in range(self.length):
            state_t = self.states[:, t, :]
            flows.append(gfn.logF_model(state_t).reshape(state_t.shape[0]))
        self.logF = torch.stack(flows, dim=1)

    def set_log_probabilities(self, logP, log_fullP, forward=True):
        """
        Store precomputed log-probabilities.

        ``logP`` must have shape ``(batch_size,)`` and ``log_fullP`` shape
        ``(batch_size, length)``; they are stored as the forward quantities
        when ``forward`` is true, otherwise as the backward ones.
        """
        assert logP.shape == (self.batch_size,)
        assert log_fullP.shape == (self.batch_size, self.length)
        if forward:
            self.logPF, self.log_fullPF = logP, log_fullP
        else:
            self.logPB, self.log_fullPB = logP, log_fullP

    def get_final_states(self):
        """Return the last state of every trajectory without its last feature column."""
        return self.states[:, -1, :-1]

    def trajectory_iterator(self):
        """Yield ``(states[i], log_rewards[i])`` for each ``i < batch_size``."""
        for i in range(self.batch_size):
            yield self.states[i], self.log_rewards[i]

    def prune(self, log_reward_threshold):
        """
        Keep only trajectories whose log-reward is strictly greater than the threshold.

        Trajectories with a log-reward equal to or below the threshold are
        removed from every per-trajectory tensor, and ``batch_size`` is updated.
        """
        mask = self.log_rewards > log_reward_threshold
        for attr in self._PER_TRAJECTORY_ATTRS:
            attribute = getattr(self, attr)
            if attribute is not None:
                setattr(self, attr, attribute[mask])
        self.batch_size = self.states.shape[0]

    def slice_trajs(self, indices):
        """
        Returns a new Trajectories object with only the trajectories at the given indices.

        States, actions, log-rewards and any computed log-probabilities / flows
        are indexed with ``indices``; attributes that are None stay None.
        """
        def take(tensor):
            return tensor[indices] if tensor is not None else None

        sliced_traj = Trajectories(
            states=self.states[indices],
            actions=take(self.actions),
            log_rewards=take(self.log_rewards),
            device=self.device,
        )
        for attr in self._COMPUTED_ATTRS:
            attribute = getattr(self, attr)
            if attribute is not None:
                setattr(sliced_traj, attr, attribute[indices])
        return sliced_traj

    def detach(self):
        """
        Detach all tensors from the computational graph.

        Replaces every non-None per-trajectory tensor attribute with its
        ``.detach()`` result, in place on this object.
        """
        for attr in self._PER_TRAJECTORY_ATTRS:
            attribute = getattr(self, attr)
            if attribute is not None:
                setattr(self, attr, attribute.detach())


class TrajectoryBuffer(Trajectories):
    """
    An extendable buffer of trajectories, used for the replay buffer.

    The buffer has a fixed ``capacity``; rows are written starting at
    ``writing_index`` and wrap around modulo ``capacity``, so once full the
    oldest rows are overwritten first. Unwritten rows hold a log-reward of
    ``-inf``.
    """

    def __init__(self, capacity, trajectory_length, dim, batch_size, device, reward_only=False):
        """
        Allocate zero-filled storage for ``capacity`` trajectories.

        :param capacity: maximum number of stored trajectories.
        :param trajectory_length: number of transitions per trajectory; states
            are stored as ``(capacity, trajectory_length + 1, dim + 1)``.
        :param dim: state dimension (one extra column is allocated).
        :param batch_size: stored as ``self.batch_size``. NOTE(review): unlike
            in Trajectories this is not the number of stored rows (see
            ``__len__``) -- presumably the sampling batch size; confirm.
        :param device: device on which all buffer tensors are allocated.
        :param reward_only: when True only states and log-rewards are stored;
            log-probability and flow tensors stay None.
        """
        super().__init__()
        self.device = device
        self.reward_only = reward_only
        self.capacity = capacity
        self.batch_size = batch_size
        self.states = torch.zeros((capacity, trajectory_length + 1, dim + 1), device=self.device)
        self.log_rewards = torch.full((capacity,), float('-inf'), device=self.device, dtype=torch.float)
        if not self.reward_only:
            self.logPF = torch.zeros((capacity,), device=self.device)
            self.logPB = torch.zeros((capacity,), device=self.device)
            self.log_fullPF = torch.zeros((capacity, trajectory_length), device=self.device)
            self.log_fullPB = torch.zeros((capacity, trajectory_length), device=self.device)
            self.logF = torch.zeros((capacity, trajectory_length), device=self.device)
        self.length = trajectory_length
        self.dim = dim
        self.stored_trajectories = 0  # number of valid rows, capped at capacity
        self.writing_index = 0  # next row to overwrite in the ring

    def extend(self, other_trajectories):
        """
        Extend the current trajectories with other_trajectories.

        Incoming tensors are detached and converted to the buffer's dtype and
        device before being copied, so the buffer never holds an autograd
        graph; ``other_trajectories`` itself is left untouched. If it holds
        more than ``capacity`` trajectories, only the last ``capacity`` are
        written, since the earlier ones would be overwritten in the same call.
        When not ``reward_only`` and ``other_trajectories.logF`` is None, the
        written rows' ``logF`` is reset to 0 rather than left stale.
        """
        n_new = other_trajectories.batch_size
        # Writing more than `capacity` rows would repeat indices, and index
        # assignment with duplicate indices leaves the stored value undefined.
        start = max(n_new - self.capacity, 0)
        writing_indices = torch.arange(
            self.writing_index + start,
            self.writing_index + n_new,
            device=self.states.device,
        ) % self.capacity

        def prepared(tensor, like):
            # Detach a view instead of mutating the caller's object.
            return tensor[start:].detach().to(device=like.device, dtype=like.dtype)

        self.states[writing_indices] = prepared(other_trajectories.states, self.states)
        self.log_rewards[writing_indices] = prepared(other_trajectories.log_rewards, self.log_rewards)
        if not self.reward_only:
            self.logPF[writing_indices] = prepared(other_trajectories.logPF, self.logPF)
            self.logPB[writing_indices] = prepared(other_trajectories.logPB, self.logPB)
            self.log_fullPF[writing_indices] = prepared(other_trajectories.log_fullPF, self.log_fullPF)
            self.log_fullPB[writing_indices] = prepared(other_trajectories.log_fullPB, self.log_fullPB)
            if other_trajectories.logF is not None:
                self.logF[writing_indices] = prepared(other_trajectories.logF, self.logF)
            else:
                # Reset reused slots so they do not keep the flow of the
                # trajectory they replaced (0 is the buffer's initial value).
                self.logF[writing_indices] = 0.0

        self.stored_trajectories = min(self.stored_trajectories + n_new, self.capacity)
        self.writing_index = (self.writing_index + n_new) % self.capacity

    def __len__(self):
        """Return the number of valid (written) rows, at most ``capacity``."""
        return self.stored_trajectories